#!/usr/bin/python3

import argparse
import gzip
import io
import json
import os
import requests
import socket
import sys
import time

# Probe endpoints. support() requests each of these before anything else and
# aborts, returning the HTTP status, if a reply is not OK.
support_test = [
    "/v1/eula",
]

# Endpoints that take no per-object id. support() fetches each of them once,
# in the order listed, through support_request().
support_global = [
    "/v1/system/license",
    "/v1/internal/system",
    "/v1/system/summary",
    "/v2/system/config",
    "/v1/system/rbac",
    "/v1/fed/member",
    "/v1/user_role",
    "/v1/user",
    "/v1/password_profile",
    "/v1/server",
    "/v1/server/_platform_/user",
    "/v1/partner/ibm_sa_config",
    "/v1/compliance/profile",
    "/v1/scan/config",
    "/v1/scan/status",
    "/v1/scan/scanner",
    "/v1/scan/platform",
    "/v1/scan/platform/platform",
    "/v1/scan/registry",
    "/v1/admission/state",
    "/v1/admission/rules",

    "/v1/domain",
    "/v1/service",
    "/v1/group",
    "/v1/custom_check",
    "/v1/process_profile",
    "/v1/file_monitor",
    "/v1/dlp/sensor",
    "/v1/dlp/group",
    "/v1/waf/sensor",
    "/v1/waf/group",
    "/v1/policy/rule",
    "/v1/response/rule",
    "/v1/conversation",

    # Log endpoints
    "/v1/log/activity",
    "/v1/log/event",
    "/v1/log/incident",
    "/v1/log/threat",
    "/v1/log/violation",
    "/v1/log/audit",

    # Debug endpoints
    "/v1/debug/internal_subnets",
    "/v1/debug/ip2workload",
    "/v1/debug/controller/sync",
    "/v1/debug/system/stats",
    "/v1/debug/registry/image/openshift",
    "/v1/debug/admission_stats",
]

# The first entry is the host listing; support() reads the "hosts" array of
# its reply. Every following entry is a %-template that support() formats
# with each listed host's "id".
support_host = [
    "/v1/host",
    "/v1/debug/ip2workload?f_node=%s",
]

# The first entry is the controller listing; support() reads the
# "controllers" array of its reply. Every following entry is a %-template
# that support() formats with each listed controller's "id".
support_controller = [
    "/v1/controller",
    "/v1/controller/%s/counter",
    # "/v1/controller/%s/logs?start=0&limit=100000", # 1000 lines * 100
    # "/v1/controller/%s/logs?start=-1&limit=100000", # 1000 lines * 100
    # "/v1/controller/%s/config",
]

# The first entry is the enforcer listing; support() reads the "enforcers"
# array of its reply. The following %-templates are formatted with enforcer
# ids, but only when support() is given a non-empty `enforcers` argument
# ("ALL" or a comma separated list of id prefixes).
support_enforcer = [
    "/v1/enforcer",
    # "/v1/enforcer/%s/config",
    "/v1/enforcer/%s/counter",
    "/v1/enforcer/%s/stats",
    "/v1/enforcer/%s/probe_summary",
    "/v1/debug/policy/rule?f_enforcer=%s",
    "/v1/session?f_enforcer=%s",
    "/v1/debug/dlp/wlrule?f_enforcer=%s",
    "/v1/debug/dlp/rule?f_enforcer=%s",
    # "/v1/enforcer/%s/logs?start=0&limit=100000", # 1000 lines * 100
    # "/v1/enforcer/%s/logs?start=-1&limit=100000", # 1000 lines * 100
]

# The first entry is the workload listing; support() reads the "workloads"
# array of its reply. Any following entry would be a %-template formatted
# with each workload's ["brief"]["id"].
# Calling every workload one by one is very expensive, so the per-workload
# template below is left disabled and only the listing is collected.
support_workload = [
    "/v2/workload?view=pod",
    # "/v1/workload/%s/process_history",
]

# Names of the environment variables RestClient falls back to when no
# server address / port is passed to its constructor.
ENV_CTRL_SERVER_IP = "CTRL_SERVER_IP"
ENV_CTRL_SERVER_PORT = "CTRL_SERVER_PORT"

class RestException(Exception):
    """Base class for the errors raised by this script's REST layer.

    Subclasses override ``message``. The template may contain
    ``%(name)s`` placeholders, which are filled from the keyword arguments
    given to the constructor. The resulting text is used as the exception
    message and is also stored in ``self.msg``.
    """

    message = "An unknown exception occurred."

    def __init__(self, **kwargs):
        try:
            msg = self.message % kwargs
        except (KeyError, TypeError, ValueError):
            # The template and the supplied arguments do not match; fall
            # back to the unformatted template rather than failing while
            # constructing an error.
            msg = self.message
        super(RestException, self).__init__(msg)
        # Always set, so callers can rely on ``.msg`` in both cases.
        self.msg = msg

class ConnectionError(RestException):
    """Raised by RestClient.request when requests reports a connection error.

    NOTE: within this module the name shadows Python's builtin
    ``ConnectionError``.
    """
    message = "Fail to connect to server"

class ConnectTimeout(RestException):
    """Raised by RestClient.request when the underlying request times out."""
    message = "Connection timeout"

class RequestError(RestException):
    """Raised by RestClient.request for any other requests-level failure."""
    message = "Request error"

class RestClient(object):
    """Small HTTPS client for the controller REST API.

    The client keeps one ``requests.Session`` carrying the authentication
    headers and sends every request to ``https://<server>:<port>``.
    """

    def __init__(self, server, port, token, cookie):
        """Set up the base URL and the authenticated session.

        Args:
            server: controller address. When empty, the CTRL_SERVER_IP
                environment variable is used, then "127.0.0.1".
            port: controller port. When empty, the CTRL_SERVER_PORT
                environment variable is used, then 10443.
            token: value sent in the "X-Auth-Token" header.
            cookie: value sent in the "X-R-Sess" header; the header is
                omitted when this is None.
        """
        # Explicit argument first, then the environment, then the default.
        self.server = server or os.environ.get(ENV_CTRL_SERVER_IP, "127.0.0.1")
        self.port = port or os.environ.get(ENV_CTRL_SERVER_PORT, 10443)

        self.url = "https://%s:%s" % (self.server, self.port)

        self.sess = requests.Session()
        self.sess.headers.update({
            "Content-Type": "application/json",
            "X-Auth-Token": token,
        })
        if cookie is not None:
            self.sess.headers.update({"X-R-Sess": cookie})

        # Requests are sent with verify=False, so silence the warning urllib3
        # would otherwise print for each of them. Older requests/urllib3
        # combinations may not expose this helper.
        try:
            if hasattr(requests.packages.urllib3, "disable_warnings"):
                requests.packages.urllib3.disable_warnings()
        except AttributeError:
            pass

    def request(self, method, path, body=None, decode=True):
        """Send one request and return its outcome.

        Args:
            method: HTTP method name, e.g. "GET".
            path: URL path (and query string) appended to the base URL.
            body: optional request body passed through as ``data``.
            decode: when true, try to parse the reply body as JSON.

        Returns:
            A ``(status_code, text, data, raw)`` tuple. ``data`` is the
            decoded JSON and ``raw`` is False when decoding was requested
            and succeeded; otherwise ``data`` is the response object itself
            and ``raw`` is True.

        Raises:
            ConnectTimeout: the request timed out.
            ConnectionError: the connection could not be established.
            RequestError: any other failure reported by requests.
        """
        url = "%s%s" % (self.url, path)
        try:
            # verify=False: the server certificate is not validated.
            resp = self.sess.request(method, url, data=body, verify=False)
        except requests.exceptions.Timeout as err:
            # Must be tested before ConnectionError: requests' ConnectTimeout
            # derives from both Timeout and ConnectionError.
            raise ConnectTimeout() from err
        except requests.exceptions.ConnectionError as err:
            raise ConnectionError() from err
        except requests.exceptions.RequestException as err:
            raise RequestError() from err

        data, raw = resp, True
        if decode:
            try:
                data, raw = resp.json(), False
            except ValueError:
                # Body is not valid JSON; hand back the response object.
                pass

        return resp.status_code, resp.text, data, raw

def support_request(path, output):
    """Fetch one endpoint and append it to the support bundle.

    Writes ``"<path>":`` followed by the reply and a trailing comma to
    *output*. The request itself is sent with ``support=true`` added to the
    query string, through the module-level ``client``.

    Args:
        path: endpoint path, with or without a query string.
        output: text stream the bundle is written to.

    Returns:
        The third element returned by ``client.request``: the decoded JSON
        reply, or the raw response object when the reply was not JSON.
    """
    output.write('"%s":\n' % path)

    separator = "&" if "?" in path else "?"
    path = "{}{}support=true".format(path, separator)
    _, text, data, raw = client.request("GET", path)
    if raw:
        # `output` is a text stream: write the decoded body, not the bytes
        # in `data.content`, which would raise TypeError.
        output.write(text)
    else:
        output.write(json.dumps(data, indent=4, sort_keys=True))
    output.write(",\n")
    return data

def _cluster_path(jointID, path):
    """Return *path*, routed through the federation API when jointID is set."""
    if jointID:
        return "/v1/fed/cluster/{}{}".format(jointID, path)
    return path


def _entries(data, key):
    """Return the list stored under *key* in a decoded reply, else []."""
    # support_request() hands back a raw response object when the reply was
    # not JSON, and the key may be absent or null; treat all of those as
    # "nothing listed" instead of crashing half-way through the bundle.
    if not isinstance(data, dict):
        return []
    return data.get(key) or []


def _collect_per_id(jointID, templates, ids, output):
    """Fetch every %-template in *templates* once for every id in *ids*."""
    for template in templates:
        path = _cluster_path(jointID, template)
        for ident in ids:
            support_request(path % ident, output)


def support(client, jointID, enforcers, output):
    """Write the whole support bundle to *output*.

    The bundle is one JSON-like object: a "Manager" and "Time" header, then
    one member per endpoint that was queried, then a dummy "END" member.

    Args:
        client: object whose ``request(method, path)`` returns
            ``(status, text, data, raw)``; used for the probe requests.
        jointID: id of a remote cluster. When set, every path is prefixed
            with ``/v1/fed/cluster/<jointID>``.
        enforcers: "ALL", or comma separated enforcer id prefixes, selecting
            the enforcers whose details are collected. When empty, only the
            enforcer listing is collected.
        output: text stream to write to.

    Returns:
        0 on success, or the HTTP status of the first probe request that
        did not return OK.
    """
    output.write('{')
    output.write('"Manager": "%s",\n' % socket.gethostname())
    output.write('"Time": "%s",\n' % time.strftime("%c"))

    # Test
    for probe in support_test:
        path = _cluster_path(jointID, probe)
        output.write('"%s": {},\n' % path)
        status, text, data, raw = client.request("GET", path)
        if status != requests.codes.ok:
            # A reply that is not JSON cannot go through json.dumps; write
            # its text as received.
            if raw:
                output.write(text + "\n")
            else:
                output.write(json.dumps(data, indent=4, sort_keys=True) + "\n")
            return status

    # Global
    for endpoint in support_global:
        support_request(_cluster_path(jointID, endpoint), output)

    # Host
    data = support_request(_cluster_path(jointID, support_host[0]), output)
    host_ids = [entry["id"] for entry in _entries(data, "hosts")]
    _collect_per_id(jointID, support_host[1:], host_ids, output)

    # Controller
    data = support_request(_cluster_path(jointID, support_controller[0]), output)
    controller_ids = [entry["id"] for entry in _entries(data, "controllers")]
    _collect_per_id(jointID, support_controller[1:], controller_ids, output)

    # Enforcer
    data = support_request(_cluster_path(jointID, support_enforcer[0]), output)
    if enforcers:
        enforcer_ids = [entry["id"] for entry in _entries(data, "enforcers")]
        if enforcers != "ALL":
            # Keep the enforcers whose id starts with any requested prefix.
            prefixes = tuple(enforcers.split(","))
            enforcer_ids = [
                ident for ident in enforcer_ids if ident.startswith(prefixes)
            ]
        _collect_per_id(jointID, support_enforcer[1:], enforcer_ids, output)

    # Workload
    data = support_request(_cluster_path(jointID, support_workload[0]), output)
    workload_ids = [
        entry["brief"]["id"] for entry in _entries(data, "workloads")
    ]
    _collect_per_id(jointID, support_workload[1:], workload_ids, output)

    # add a dummy entry to correct the trailing comma
    output.write('"END":{}\n')
    output.write('}')
    return 0

# Command-line entry point: parse the options, build the REST client and write
# the support bundle either to a gzip file (-o) or to standard output.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='NeuVector Support Script.')
    parser.add_argument('-s', '--server', help='controller IP address.')
    parser.add_argument('-p', '--port', type=int, help='controller port.')
    parser.add_argument('-t', '--token', help='login token.')
    parser.add_argument('-r', '--ress', help='Rancher cookie.')
    parser.add_argument('-o', '--output', help='output filename.')
    # "ALL" selects every enforcer; otherwise ids are matched by prefix. When
    # the option is omitted, only the enforcer listing is collected.
    parser.add_argument('-e', '--enforcers', help='comma separated enforcer ids.')
    # if not specified, meaning get local support logs
    parser.add_argument('-j', '--joint', help='remote joint cluster id.')
    args = parser.parse_args()

    # NOTE: this must remain a module-level name; support_request() reads the
    # global `client`.
    client = RestClient(args.server, args.port, args.token, args.ress)

    rc = 0
    if args.output:
        (path, name) = os.path.split(args.output)
        if not name:
            sys.stderr.write(
                "Invalid output filename: %s\n" % args.output)
            sys.exit(-1)

        # Write to a hidden sibling file first, so that a partial bundle never
        # appears under the requested name.
        tmpfile = os.path.join(path, "." + name)
        try:
            with gzip.open(tmpfile, 'wt', encoding='utf-8') as output:
                rc = support(client, args.joint, args.enforcers, output)
            os.rename(tmpfile, args.output)
        finally:
            # Only still present when collection or the rename failed.
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
    else:
        rc = support(client, args.joint, args.enforcers, sys.stdout)

    sys.exit(rc)
